/* Copyright 2013 Tobias Marschall
 *
 * This file is part of CLEVER.
 *
 * CLEVER is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * CLEVER is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with CLEVER.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <boost/math/distributions.hpp>

#include "BamHelper.h"
#include "SumOfBinomials.h"

#include "Genotyper.h"

using namespace std;


// Mendelian transmission table, indexed as medelian_factors[mother][father][child].
// Each index is a genotype code 0, 1 or 2 (0 meaning the variant is not present);
// the entry is the probability of the child genotype given both parent genotypes.
// Every {..} triple below lists the child codes 0, 1, 2 in that order and sums to 1.
// A child carrying the variant while both parents have code 0 (de novo) gets
// probability 0 in this table.
double medelian_factors[3][3][3] = {
	// mother = 0
	{
		{1,    0,   0   },  // father = 0
		{0.5,  0.5, 0   },  // father = 1
		{0,    1,   0   }   // father = 2
	},
	// mother = 1
	{
		{0.5,  0.5, 0   },  // father = 0
		{0.25, 0.5, 0.25},  // father = 1
		{0,    0.5, 0.5 }   // father = 2
	},
	// mother = 2
	{
		{0,    1,   0   },  // father = 0
		{0,    0.5, 0.5 },  // father = 1
		{0,    0,   1   }   // father = 2
	}
};

/**
 * Creates a genotyper that applies one insert size distribution (given by
 * isize_mean and isize_stddev) to all reads; no read-group specific
 * parameters are used. Masking is disabled until enableMasking() is called.
 */
Genotyper::Genotyper(double isize_mean, double isize_stddev, double variant_prior, int max_length_diff_split_split, int max_distance_split_split, int mapq_threshold, int bam_window_size, bool use_insert_size, bool use_split_reads, int split_read_border_size) {
	// Insert size model: a single global distribution, no per-read-group table.
	readgroup_params = 0;
	this->isize_stddev = isize_stddev;
	this->isize_mean = isize_mean;

	// Prior probability used when turning read evidence into a genotype.
	this->variant_prior = variant_prior;

	// Which evidence types to collect and how reads are filtered.
	this->use_split_reads = use_split_reads;
	this->use_insert_size = use_insert_size;
	this->mapq_threshold = mapq_threshold;
	this->bam_window_size = bam_window_size;

	// Tolerances for matching a split alignment to the variant.
	this->split_read_border_size = split_read_border_size;
	this->max_distance_split_split = max_distance_split_split;
	this->max_length_diff_split_split = max_length_diff_split_split;

	// Negative values mean: masking not enabled.
	masking_distance = -1;
	coverage_threshold = -1;
}

/**
 * Creates a genotyper that looks up the insert size distribution per read
 * group. readgroup_params must be non-null and non-empty (checked by assert);
 * the pointer is stored, not copied. The global mean is set to NaN and the
 * global standard deviation to the average of the per-read-group standard
 * deviations. Masking is disabled until enableMasking() is called.
 */
Genotyper::Genotyper(const readgroup_params_map_t* readgroup_params, double variant_prior, int max_length_diff_split_split, int max_distance_split_split, int mapq_threshold, int bam_window_size, bool use_insert_size, bool use_split_reads, int split_read_border_size) {
	assert(readgroup_params != 0);
	assert(readgroup_params->size() > 0);

	// Insert size model: per-read-group table, so there is no global mean.
	this->readgroup_params = readgroup_params;
	this->isize_mean = numeric_limits<double>::quiet_NaN();

	// Prior probability used when turning read evidence into a genotype.
	this->variant_prior = variant_prior;

	// Which evidence types to collect and how reads are filtered.
	this->use_split_reads = use_split_reads;
	this->use_insert_size = use_insert_size;
	this->mapq_threshold = mapq_threshold;
	this->bam_window_size = bam_window_size;

	// Tolerances for matching a split alignment to the variant.
	this->split_read_border_size = split_read_border_size;
	this->max_distance_split_split = max_distance_split_split;
	this->max_length_diff_split_split = max_length_diff_split_split;

	// Negative values mean: masking not enabled.
	masking_distance = -1;
	coverage_threshold = -1;

	// Average the standard deviations over all read groups.
	this->isize_stddev = 0;
	for (readgroup_params_map_t::const_iterator group = readgroup_params->begin(); group != readgroup_params->end(); ++group) {
		this->isize_stddev += group->second.stddev;
	}
	this->isize_stddev /= readgroup_params->size();
}

/**
 * Stores the masking parameters. extractVariationStats() rejects a variant
 * whose split-read or insert-size coverage exceeds coverage_threshold and, if
 * a region container is supplied and masking_distance is non-negative, records
 * the variant position extended by masking_distance as a masked region.
 */
void Genotyper::enableMasking(int coverage_threshold, int masking_distance) {
	this->masking_distance = masking_distance;
	this->coverage_threshold = coverage_threshold;
}

/**
 * Collects read evidence for or against a deletion or insertion from the
 * alignments in a window of bam_window_size around the variant.
 *
 * Two evidence types are gathered, each only if enabled in the constructor:
 *  - split-read evidence: an alignment spanning the variant center either
 *    contains a matching indel (support for ALT) or contains no indel at all
 *    (support for REF);
 *  - insert-size evidence: the unsequenced segment between both mates of a
 *    pair spans the variant center and its length is compared against the
 *    insert size distribution.
 * For every piece of evidence a GenotypeDistribution (absent, heterozygous,
 * homozygous) is appended to read_gt_probabilities of the result.
 *
 * @param v              variant to examine; only DELETION and INSERTION are handled.
 * @param bam_reader     reader used to resolve the chromosome and fetch alignments.
 * @param masked_regions may be 0; if given and masking is enabled, receives the
 *                       region around a variant rejected for excessive coverage.
 * @param used_reads     may be 0; if given, reads whose name is contained are
 *                       skipped, and the name of every read considered here is added.
 * @return the collected statistics, or a null pointer if the chromosome is
 *         unknown to the reader, the variant type is unsupported, or coverage
 *         exceeds the threshold set via enableMasking().
 * @throws std::runtime_error if read-group specific insert sizes are in use and
 *         a read has no RG tag or an unknown read group.
 */
std::auto_ptr<Genotyper::variation_stats_t> Genotyper::extractVariationStats(const Variation& v, BamTools::BamReader& bam_reader, OverlappingRegions* masked_regions, read_name_set_t* used_reads) const {
	int ref_id = bam_reader.GetReferenceID(v.getChromosome());
	if (ref_id < 0) {
		return auto_ptr<variation_stats_t>(0);
	}
	int length;
	int center;
	int region_start;
	int region_end;
	if (v.getType() == Variation::DELETION) {
		length = (int)v.getCoordinate2()-(int)v.getCoordinate1();
		center = ((int)v.getCoordinate1() + (int)v.getCoordinate2()) / 2;
		region_start = max(0,((int)v.getCoordinate1()) - bam_window_size);
		region_end = (int)v.getCoordinate2() + bam_window_size;
// 		cerr << "  deletion: " << v.getCoordinate1() << "-" << v.getCoordinate2() << ", center: " << center << ", bam_window_size: " << bam_window_size << ", region: " << region_start << "-" << region_end << endl;
	} else if (v.getType() == Variation::INSERTION) {
		// for insertions, coordinate2 holds the length and coordinate1 the position
		length = (int)v.getCoordinate2();
		center = (int)v.getCoordinate1();
		region_start = max(0, center - bam_window_size);
		region_end = center + bam_window_size;
// 		cerr << "  insertion: " << v.getCoordinate1() << ", length " << v.getCoordinate2() << ", sequence: " << v.getSequence() << ", bam_window_size: " << bam_window_size << ", region: " << region_start << "-" << region_end << endl;
	} else {
		return auto_ptr<variation_stats_t>(0);
	}
	auto_ptr<variation_stats_t> result(new variation_stats_t());
	// The alignments handed out by readRegion are owned by this function. Keeping
	// them in a scope guard releases them on every exit path, including the
	// exceptions thrown below for missing/unknown read groups.
	struct AlignmentPairOwner {
		vector<BamHelper::aln_pair_t> pairs;
		~AlignmentPairOwner() {
			for (size_t i=0; i<pairs.size(); ++i) {
				delete pairs[i].first;
				delete pairs[i].second;
			}
		}
	};
	AlignmentPairOwner aln_owner;
	vector<BamHelper::aln_pair_t>& aln_pairs = aln_owner.pairs;
	BamHelper::readRegion(bam_reader, ref_id, region_start, region_end, &aln_pairs);
	for (size_t i=0; i<aln_pairs.size(); ++i) {
		const BamTools::BamAlignment& aln1 = *aln_pairs[i].first;
		const BamTools::BamAlignment& aln2 = *aln_pairs[i].second;
// 		cerr << "-----------------------------------------------" << endl;
// 		cerr << "  considering read " << aln1.Name << endl;
		if (used_reads != 0) {
			// check whether the read has been used before
			read_name_set_t::const_iterator it = used_reads->find(aln1.Name);
			if (it != used_reads->end()) {
// 				cerr << "  ... has been used before, skipping." << endl;
				continue;
			}
		}
		bool using_this_read = false;
		// Look for split-read type evidence
		if (use_split_reads) {
			// Does one of the alignments cover center position?
			for (size_t j=0; j<=1; ++j) {
// 				cerr << "    checking read " << j << " for split alignments" << endl;
				const BamTools::BamAlignment& aln = (j==0)?aln1:aln2;
				if ((mapq_threshold > 0) && (aln.MapQuality < mapq_threshold)) {
// 					cerr << "      ... failed MAPQ threshold, skipping." << endl;
					continue;
				}
				if ((aln.Position<center) && (center<aln.GetEndPosition())) {
					bool is_split = false;
					bool is_regular = false;
					// mapping error probability derived from MAPQ, but never below 0.05
					double p_wrong = max(0.05, pow(10.0,-((double)aln.MapQuality)/10.0));
					// result->split_read_probabilies.push_back(read_error_probabilies_t(p_wrong/2.0, p_wrong/2.0));
					using_this_read = true;
					auto_ptr<vector<Variation> > aln_vars = BamHelper::variationsFromAlignment(bam_reader.GetReferenceData(), aln);
					if (aln_vars->size() == 0) {
						// TODO: only treating reads with no indels at all as "regular" might be too strict in some settings.
						// for insertions, additionally check whether overlap is "sufficient":
						if (!((v.getType() == Variation::INSERTION) && 
							((std::abs(((int)v.getCoordinate1())-aln.Position) < split_read_border_size) || (std::abs(((int)v.getCoordinate1())-aln.GetEndPosition()) < split_read_border_size)))) {
							is_regular = true;
						} else {
// 							cerr << "      ... omitting, breakpoint too close to read border." << endl;
						}
					} else {
						// search the indels of this alignment for one matching the variant
						vector<Variation>::const_iterator var_it = aln_vars->begin();
						for (; var_it != aln_vars->end(); ++var_it) {
							const Variation& v2 = *var_it;
							if (v.getType() != v2.getType()) continue;
							int length2, center2;
							if (var_it->getType() == Variation::DELETION) {
								length2 = (int)v2.getCoordinate2() - (int)v2.getCoordinate1();
								center2 = ((int)v2.getCoordinate1() + (int)v2.getCoordinate2()) / 2;
							} else if (var_it->getType() == Variation::INSERTION) {
								length2 = (int)v2.getCoordinate2();
								center2 = (int)v2.getCoordinate1();
							} else {
								continue;
							}
							if (abs(length-length2) > max_length_diff_split_split) continue;
							if (abs(center-center2) > max_distance_split_split) continue;
							// Check whether breakpoints are too close to alignment start/end
							if ((v2.leftBreakpoint() - aln.Position < split_read_border_size) ||
								(aln.GetEndPosition() - v2.rightBreakpoint() < split_read_border_size)) {
								break;
							}
							// TODO: in case of insertions: test for sequence identity?!
							is_split = true;
							break;
						}
					} 
					assert(!(is_regular && is_split));
					if (is_split) {
// 						cerr << "      --> using it as evidence for ALT (split alignment)" << endl;
						result->split_evidence.alt_support += 1;
						result->read_gt_probabilities.push_back(GenotypeDistribution(1.0/3.0*p_wrong, 1.0/3.0, 2.0/3.0-1.0/3.0*p_wrong));
					}
					if (is_regular) {
// 						cerr << "      --> using it as evidence for REF (regular alignment)" << endl;
						result->split_evidence.ref_support += 1;
						result->read_gt_probabilities.push_back(GenotypeDistribution(2.0/3.0-1.0/3.0*p_wrong, 1.0/3.0, 1.0/3.0*p_wrong));
					}
				} else {
// 					cerr << "      --> no overlap" << endl;
				}
			}
		}
		// Look for insert-size type evidence
		if (use_insert_size) {
			// Does internal segment of alignment pair cover deletion?
			if ((mapq_threshold == 0) || ((aln1.MapQuality >= mapq_threshold) && (aln2.MapQuality >= mapq_threshold))) {
				double p_wrong1 = pow(10.0,-((double)aln1.MapQuality)/10.0);
				double p_wrong2 = pow(10.0,-((double)aln2.MapQuality)/10.0);
				// compute probability that at least one of the two alignments is wrong
				double p_wrong = max(0.05,p_wrong1*p_wrong2 + (1.0-p_wrong1)*p_wrong2 + p_wrong1*(1.0-p_wrong2));
				const BamTools::BamAlignment& left = (aln1.Position<=aln2.Position)?aln1:aln2;
				const BamTools::BamAlignment& right = (aln1.Position<=aln2.Position)?aln2:aln1;
				// Check whether there is soft-clipping either at the right end of the left read 
				// or at the left end of the right read (i.e. on the "inside").
				// An alignment without CIGAR operations counts as not clipped.
				const bool left_clipped_inside = !left.CigarData.empty() &&
					(left.CigarData.back().Type == 'S') && (left.CigarData.back().Length > 2);
				const bool right_clipped_inside = !right.CigarData.empty() &&
					(right.CigarData.front().Type == 'S') && (right.CigarData.front().Length > 2);
				// Pairs soft-clipped on the inside yield no insert-size evidence. They
				// are excluded through this flag rather than by skipping to the next
				// pair, so that a read already used as split-read evidence above is
				// still recorded in used_reads at the end of the iteration.
				const bool clipped_inside = left_clipped_inside || right_clipped_inside;
// 				if (clipped_inside) cerr << "    soft-clipped on the inside: skipping pair." << endl;
				// Does internal segment overlap deletion center? And if so, does it support the deletion?
				int insert_size = right.Position - left.GetEndPosition();
// 				cerr << "    insert_size: " << insert_size << ", left.GetEndPosition(): " << left.GetEndPosition() << ", center: " << center << ", right.Position: " << right.Position << endl;
				if (!clipped_inside && (insert_size > 0) && (left.GetEndPosition() <= center) && (center <= right.Position)) {
					// Determine isize_mean to be used for this alignment
					double isize_mean = this->isize_mean;
					double isize_stddev = this->isize_stddev;
					string rg = "";
					if (readgroup_params != 0) {
						if (!aln1.GetTag("RG", rg)) {
							ostringstream oss;
							oss << "Using read group specific insert size means, but RG tag missing in read " << aln1.Name;
							throw std::runtime_error(oss.str());
						}
						readgroup_params_map_t::const_iterator it = readgroup_params->find(rg);
						if (it == readgroup_params->end()) {
							ostringstream oss;
							oss << "Using read group specific insert size means, but read group \"" << rg << "\" of read \"" << aln1.Name << "\" is unknown.\"";
							throw std::runtime_error(oss.str());
						}
						isize_mean = it->second.mean;
						isize_stddev = it->second.stddev;
// 						cerr << "   found read-group specific distribution: read group " << rg << ", mean: " << isize_mean << endl;
					}
					// is deletion/insertion fully included in internal segment?
					bool includes_indel = false;
					// does length of internal segment looks more like a deletion than not?
					bool length_supports_indel = false;
					if (v.getType() == Variation::DELETION) {
						includes_indel = (left.GetEndPosition() < (int)v.getCoordinate1()) && ((int)v.getCoordinate2() <= right.Position);
						length_supports_indel = (insert_size - isize_mean >= length/2.0);
					} else if (v.getType() == Variation::INSERTION) {
						includes_indel = (left.GetEndPosition() < center) && (center <= right.Position);
						length_supports_indel = (insert_size - isize_mean <= -length/2.0);
					} else {
						// unreachable: other types were rejected at the top of the function
						assert(false);
					}
					if (length_supports_indel) {
						if (includes_indel) {
							result->insert_evidence.alt_support += 1;
// 							cerr << "   supporting ALT: diff = " << (insert_size - isize_mean) << endl;
						}
					} else {
						result->insert_evidence.ref_support += 1;
// 						cerr << "   supporting REF: diff = " << (insert_size - isize_mean) << endl;
					}
					using_this_read = true;
					// Densities of the observed deviation from the mean insert size under
					// "variant absent" (f_abs) and "variant homozygous" (f_hom).
					boost::math::normal_distribution<> normal(0.0, isize_stddev);
					double f_hom = 0.0;
					if (v.getType() == Variation::DELETION) {
						f_hom = boost::math::pdf(normal, insert_size-isize_mean-length);
					} else if (v.getType() == Variation::INSERTION) {
						f_hom = boost::math::pdf(normal, insert_size-isize_mean+length);
					} else {
						// unreachable: other types were rejected at the top of the function
						assert(false);
					}
					double f_abs = boost::math::pdf(normal, insert_size-isize_mean);
					// Z scales f_abs and f_hom to sum to 2/3; heterozygous is fixed at 1/3
					double Z = 3.0/2.0 * (f_abs + f_hom);
					result->read_gt_probabilities.push_back(GenotypeDistribution((1.0-p_wrong)*f_abs/Z + 1.0/3.0*p_wrong, 1.0/3.0, (1.0-p_wrong)*f_hom/Z + 1.0/3.0*p_wrong));
				} else {
// 					cerr << "   ... soft-clipped on the inside or internal segment does not contain center point" << endl;
				}
			} else {
// 				cerr << "   ... failed MAPQ threshold" << endl;
			}
		}
		if (using_this_read && (used_reads != 0)) {
			used_reads->insert(aln1.Name);
		}
	}
	if (coverage_threshold >= 0) {
		if ((result->split_evidence.coverage() > coverage_threshold) || (result->insert_evidence.coverage() > coverage_threshold)) {
			if ((masked_regions != 0) && (masking_distance >= 0)) {
				// NOTE(review): if getCoordinate1() returns an unsigned type, subtracting a
				// larger masking_distance would wrap around -- confirm against Variation.
				if (v.getType() == Variation::DELETION) {
					masked_regions->add(ref_id, v.getCoordinate1() - masking_distance, v.getCoordinate2() + masking_distance - 1);
				} else if (v.getType() == Variation::INSERTION) {
					masked_regions->add(ref_id, v.getCoordinate1() - masking_distance, v.getCoordinate1() + masking_distance - 1);
				} else {
					assert(false);
				}
			}
			return auto_ptr<variation_stats_t>(0);
		}
	}
	return result;
}

/**
 * Turns a count of supporting reads into a genotype distribution.
 *
 * For each of the three genotypes a SumOfBinomials is filled with one trial per
 * read, using the per-read probability of showing support under that genotype:
 * p_fp if the variant is absent, 1 - p_fn if it is homozygous, and the average
 * of the two if it is heterozygous. Each distribution is then evaluated at
 * `support` and the three values are normalized to sum to one.
 *
 * @param support            number of supporting reads; must not exceed the
 *                           number of entries in read_probabilities (assert).
 * @param read_probabilities false positive / false negative rates per read.
 * @return distribution over (absent, heterozygous, homozygous).
 */
GenotypeDistribution Genotyper::compute_genotype(int support, const std::vector<read_error_probabilies_t>& read_probabilities) const {
	SumOfBinomials absent_dist;
	SumOfBinomials hetero_dist;
	SumOfBinomials homo_dist;
	assert(support <= read_probabilities.size());
	typedef std::vector<read_error_probabilies_t>::const_iterator read_iterator_t;
	for (read_iterator_t read = read_probabilities.begin(); read != read_probabilities.end(); ++read) {
		absent_dist.add(read->p_fp, 1);
		hetero_dist.add(0.5*read->p_fp + 0.5*(1.0 - read->p_fn), 1);
		homo_dist.add(1.0 - read->p_fn, 1);
	}
	const double p_absent = absent_dist.probability(support);
	const double p_hetero = hetero_dist.probability(support);
	const double p_homo = homo_dist.probability(support);
	// summation order kept as (absent + homozygous) + heterozygous
	const double total = p_absent + p_homo + p_hetero;
	return GenotypeDistribution(p_absent/total, p_hetero/total, p_homo/total);
}

std::auto_ptr<GenotypeDistribution> Genotyper::computeGenotype(const Variation& v, const variation_stats_t& stats, GenotypeDistribution* raw_genotype) const {
//	p = max(p, 0.01);
	auto_ptr<GenotypeDistribution> result(new GenotypeDistribution(1.0-variant_prior, variant_prior/2.0, variant_prior/2.0));
	if (raw_genotype != 0) {
		// start with uniform distribution
		*raw_genotype = GenotypeDistribution();
	}
// 	cerr << "  prior: " << (*result) << endl;
	for (size_t i=0; i<stats.read_gt_probabilities.size(); ++i) {
		*result = *result * stats.read_gt_probabilities[i];
		if (raw_genotype != 0) {
			*raw_genotype = *raw_genotype * stats.read_gt_probabilities[i];
		}
	}
// 	cerr << "  result: " << (*result) << " --> " << result->likeliestGenotypeString() << endl;
	return result;
}

/**
 * Jointly genotypes a mother/father/child trio.
 *
 * If both parents have present() below denovo_threshold and the child has
 * notPresent() below denovo_threshold, the call is a de novo event: parents
 * ABSENT, child HETEROZYGOUS, and the input distributions are returned
 * unchanged as posteriors.
 *
 * Otherwise all 27 genotype combinations are weighted by the product of the
 * three individual probabilities and the entry of medelian_factors. The
 * combination with the largest weight gives the called genotypes; the
 * marginals of the weights, normalized, give the posteriors.
 */
auto_ptr<Genotyper::trio_genotype_t> Genotyper::genotypeTrio(const GenotypeDistribution& mother, const GenotypeDistribution& father, const GenotypeDistribution& child, double denovo_threshold) {
	auto_ptr<trio_genotype_t> trio(new trio_genotype_t());
	const bool de_novo = (mother.present() < denovo_threshold) && (father.present() < denovo_threshold) && (child.notPresent() < denovo_threshold);
	if (de_novo) {
		trio->mother_posterior = mother;
		trio->father_posterior = father;
		trio->child_posterior = child;
		trio->mother = GenotypeDistribution::ABSENT;
		trio->father = GenotypeDistribution::ABSENT;
		trio->child = GenotypeDistribution::HETEROZYGOUS;
		return trio;
	}
	const GenotypeDistribution all_zero(0.0, 0.0, 0.0);
	trio->mother_posterior = all_zero;
	trio->father_posterior = all_zero;
	trio->child_posterior = all_zero;
	double max_joint = 0.0;
	for (int m = 0; m < 3; ++m) {
		const GenotypeDistribution::genotype_t gm = static_cast<GenotypeDistribution::genotype_t>(m);
		const double p_m = mother.probability(gm);
		for (int f = 0; f < 3; ++f) {
			const GenotypeDistribution::genotype_t gf = static_cast<GenotypeDistribution::genotype_t>(f);
			const double p_mf = p_m * father.probability(gf);
			for (int c = 0; c < 3; ++c) {
				const GenotypeDistribution::genotype_t gc = static_cast<GenotypeDistribution::genotype_t>(c);
				const double joint = p_mf * child.probability(gc) * medelian_factors[m][f][c];
				// accumulate the marginals
				trio->mother_posterior.probability(gm) += joint;
				trio->father_posterior.probability(gf) += joint;
				trio->child_posterior.probability(gc) += joint;
				// remember the first combination reaching the largest weight
				if (joint > max_joint) {
					max_joint = joint;
					trio->mother = gm;
					trio->father = gf;
					trio->child = gc;
				}
			}
		}
	}
	trio->mother_posterior.normalize();
	trio->father_posterior.normalize();
	trio->child_posterior.normalize();
	return trio;
}

std::auto_ptr<Genotyper::quartet_genotype_t> Genotyper::genotypeQuartet(const GenotypeDistribution& mother, const GenotypeDistribution& father, const GenotypeDistribution& child1, const GenotypeDistribution& child2, double denovo_threshold, bool monozygotic) {
	auto_ptr<quartet_genotype_t> result(new quartet_genotype_t());
	if ((mother.present() < denovo_threshold) && (father.present() < denovo_threshold) && (child1.notPresent() < denovo_threshold) && (child2.present() < denovo_threshold)) {
		result->mother = GenotypeDistribution::ABSENT;
		result->father = GenotypeDistribution::ABSENT;
		result->child1 = GenotypeDistribution::HETEROZYGOUS;
		result->child2 = GenotypeDistribution::ABSENT;
		result->mother_posterior = mother;
		result->father_posterior = father;
		result->child1_posterior = child1;
		result->child2_posterior = child2;
		return result;
	}
	if ((mother.present() < denovo_threshold) && (father.present() < denovo_threshold) && (child1.present() < denovo_threshold) && (child2.notPresent() < denovo_threshold)) {
		result->mother = GenotypeDistribution::ABSENT;
		result->father = GenotypeDistribution::ABSENT;
		result->child1 = GenotypeDistribution::ABSENT;
		result->child2 = GenotypeDistribution::HETEROZYGOUS;
		result->mother_posterior = mother;
		result->father_posterior = father;
		result->child1_posterior = child1;
		result->child2_posterior = child2;
		return result;
	}
	result->mother_posterior = GenotypeDistribution(0.0, 0.0, 0.0);
	result->father_posterior = GenotypeDistribution(0.0, 0.0, 0.0);
	result->child1_posterior = GenotypeDistribution(0.0, 0.0, 0.0);
	result->child2_posterior = GenotypeDistribution(0.0, 0.0, 0.0);
	double best_p = 0.0;
	for (int m = 0; m <= 2; ++m) {
		GenotypeDistribution::genotype_t m_genotype = (GenotypeDistribution::genotype_t)m;
		for (int f = 0; f <= 2; ++f) {
			GenotypeDistribution::genotype_t f_genotype = (GenotypeDistribution::genotype_t)f;
			for (int c1 = 0; c1 <= 2; ++c1) {
				GenotypeDistribution::genotype_t c1_genotype = (GenotypeDistribution::genotype_t)c1;
				for (int c2 = 0; c2 <= 2; ++c2) {
					GenotypeDistribution::genotype_t c2_genotype = (GenotypeDistribution::genotype_t)c2;
					if (monozygotic && (c1 != c2)) continue;
					double p = mother.probability(m_genotype) * father.probability(f_genotype) * child1.probability(c1_genotype) *
						child2.probability(c2_genotype) * medelian_factors[m][f][c1] * medelian_factors[m][f][c2];
					result->mother_posterior.probability(m_genotype) += p;
					result->father_posterior.probability(f_genotype) += p;
					result->child1_posterior.probability(c1_genotype) += p;
					result->child2_posterior.probability(c2_genotype) += p;
					if (p > best_p) {
						best_p = p;
						result->mother = (GenotypeDistribution::genotype_t)m;
						result->father = (GenotypeDistribution::genotype_t)f;
						result->child1 = (GenotypeDistribution::genotype_t)c1;
						result->child2 = (GenotypeDistribution::genotype_t)c2;
					}
				}
			}
		}
	}
	result->mother_posterior.normalize();
	result->father_posterior.normalize();
	result->child1_posterior.normalize();
	result->child2_posterior.normalize();
	return result;
}
